Trajectory Analysis

An example of a lasagna plot built with the `cell()` mark. Drawing one cell per message makes it easy to visualise per-message information, which is incredibly useful for getting a quick overview of an agent's trajectory.

In this notebook the lasagna plot is used to visualise the tools called over each trajectory, apparent refusals, and message length.

Data Preparation

import pandas as pd
from inspect_ai.analysis.beta import messages_df, MessageColumns, EvalColumns, SampleSummary
from inspect_viz import Data, Selection
from inspect_viz.mark import cell, bar_x, text
from inspect_viz.plot import plot, legend
from inspect_viz.transform import count
from inspect_viz.layout import vconcat
from inspect_viz.interactor import toggle_y
# Load the message-level data.
# Building the frame from the raw eval log is slow, so it is cached in
# cybench_messages.csv. To regenerate the cache, run:
#   logs = "logs/cybench/2025-06-28T12-26-01+00-00_cybench_cHgZahEpPHqsSx8GfGjKai.eval"
#   df = messages_df(logs, columns=MessageColumns + EvalColumns + SampleSummary)
#   df.to_csv("cybench_messages.csv")
df = pd.read_csv("cybench_messages.csv")

# Add message-level features that are interesting to visualise.
# Length of the message text (NaN where a message has no text content).
df["message_len"] = df["content"].str.len()
# Crude refusal heuristic (you could obviously do better): does the text
# contain the literal word "apologize"? na=False keeps the column boolean.
# Without it, messages with no content yield NaN, the column degrades to
# object dtype, and the refusal plot shows a spurious third category.
df["refusal"] = df["content"].str.contains("apologize", regex=False, na=False)

# Colour messages by the tool they called, falling back to the message role
# for messages that did not call a tool.
df["tool_call_function"] = df["tool_call_function"].fillna(df["role"])

# Highest message index across all trajectories (used for x-axis ticks).
# Converted to a plain int so it can be passed straight to range().
max_messages = int(df["order"].max())

# Keep only the columns the plots below need (optional).
tools_df = df[[
    "id",
    "sample_id",
    "message_id",
    "role",
    "source",
    "tool_calls",
    "tool_call_id",
    "tool_call_function",
    "tool_call_error",
    "order",
    "content",
    "limit",
    "score_includes",
    "message_len",
    "refusal",
]]

# Wrap the frame as inspect_viz Data and display its schema.
tools_data = Data.from_dataframe(tools_df)
tools_data
Viz Data (7,825 rows x 15 columns)
--------------------------------------------------------------------------------
id                                       String                                  
sample_id                                String                                  
message_id                               String                                  
role                                     String                                  
source                                   String                                  
tool_calls                               String                                  
tool_call_id                             String                                  
tool_call_function                       String                                  
tool_call_error                          String                                  
order                                    Int64                                   
content                                  String                                  
limit                                    String                                  
score_includes                           String                                  
message_len                              Float64                                 
refusal                                  Object                                  

Overview of Trajectory

  • Visualise the interaction between the assistant and the user.
  • See which tools the assistant tends to use frequently.
  • See what types of behaviour tend to result in a limit being hit.
# Single-click selection shared by both charts: toggle_y on the lasagna plot
# writes to it, and both the cell and bar marks filter on it.
click = Selection("single")

# Tick every 50th message index along the x axis.
message_ticks = list(range(0, max_messages, 50))

# Lasagna cells: x is the message position ("order"), y is the row "id",
# colour is the tool called (or the role when no tool was called).
trajectory_cells = cell(
    data=tools_data,
    x="order",
    y="id",
    fill="tool_call_function",
    filter_by=click,
)

# Label each row with its "limit" value, anchored at the right of the frame.
limit_labels = text(tools_data, text="limit", y="id", dx=25, frame_anchor="right", font_size=8)

trajectory_plot = plot(
    trajectory_cells,
    toggle_y(target=click),
    limit_labels,
    width=2000,
    legend=legend("color", location="right", margin_left=100),
    x_tick_rotate=270,
    x_ticks=message_ticks,
    x_tick_size=4,
    x_tick_padding=20,
    x_label=None,
    y_label=None,
    color_domain="fixed",
    margin_left=200,
    x_domain="fixed",
    y_domain="fixed",
)

# Bar chart counting messages per tool/role under the current selection.
tool_counts_plot = plot(
    bar_x(tools_data, x=count(), y="tool_call_function", fill="tool_call_function", filter_by=click),
    y_label=None,
    height=200,
    margin_left=200,
    color_domain="fixed",
)

vconcat(trajectory_plot, tool_counts_plot)

Refusals

(using a very basic heuristic: a message counts as a refusal if its text contains the word "apologize")

# Single-click selection: toggle_y writes the clicked row to it and the cell
# mark filters on it.
click = Selection("single")

# Lasagna cells coloured by the per-message "refusal" flag.
refusal_cells = cell(
    data=tools_data,
    x="order",
    y="id",
    fill="refusal",
    filter_by=click,
)

plot(
    refusal_cells,
    toggle_y(target=click),
    width=1000,
    legend=legend("color", location="right", label="Refusal"),
    x_tick_rotate=270,
    x_tick_padding=20,
    # tick every 50th message index
    x_ticks=list(range(0, max_messages, 50)),
    x_label=None,
    y_label=None,
    color_domain="fixed",
    margin_left=200,
    x_domain="fixed",
    y_domain="fixed",
)

Message Length

# Single-click selection: toggle_y writes the clicked row to it and the cell
# mark filters on it.
click = Selection("single")

# Lasagna cells coloured by the length of each message's text.
length_cells = cell(
    data=tools_data,
    x="order",
    y="id",
    fill="message_len",
    filter_by=click,
)

plot(
    length_cells,
    toggle_y(target=click),
    width=1000,
    legend=legend("color", location="right", label="Message Length"),
    x_tick_rotate=270,
    x_tick_padding=20,
    # tick every 50th message index
    x_ticks=list(range(0, max_messages, 50)),
    x_label=None,
    y_label=None,
    color_domain="fixed",
    margin_left=200,
    x_domain="fixed",
    y_domain="fixed",
)